f80a2f
@@ -259,6 +259,16 @@ public Object process(Node nd, Stack<Node> stack,
       context.currentMapJoinOperators.clear();
     }
 
+    // Cut the tree here as described above, and remember that parent work may have to
+    // be connected to this work later. Done before the union handling, which may reassign work.
+    for (Operator<?> parent : new ArrayList<Operator<?>>(root.getParentOperators())) {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Removing " + parent + " as parent from " + root);
+      }
+      context.leafOperatorToFollowingWork.put(parent, work);
+      root.removeParent(parent);
+    }
+
     if (!context.currentUnionOperators.isEmpty()) {
       // if there are union all operators we need to add the work to the set
       // of union operators.
@@ -288,21 +298,6 @@ public Object process(Node nd, Stack<Node> stack,
       work = unionWork;
     }
 
-
-    // This is where we cut the tree as described above. We also remember that
-    // we might have to connect parent work with this work later.
-    boolean removeParents = false;
-    for (Operator<?> parent: new ArrayList<Operator<?>>(root.getParentOperators())) {
-      removeParents = true;
-      context.leafOperatorToFollowingWork.put(parent, work);
-      LOG.debug("Removing " + parent + " as parent from " + root);
-    }
-    if (removeParents) {
-      for (Operator<?> parent : new ArrayList<Operator<?>>(root.getParentOperators())) {
-        root.removeParent(parent);
-      }
-    }
-
     // We're scanning a tree from roots to leaf (this is not technically
     // correct, demux and mux operators might form a diamond shape, but
     // we will only scan one path and ignore the others, because the
@@ -350,19 +345,14 @@ public Object process(Node nd, Stack<Node> stack,
           // this can only be possible if there is merge work followed by the union
           UnionWork unionWork = (UnionWork) followingWork;
           int index = getFollowingWorkIndex(tezWork, unionWork, rs);
-          if (index != -1) {
-            BaseWork baseWork = tezWork.getChildren(unionWork).get(index);
-            if (baseWork instanceof MergeJoinWork) {
-              MergeJoinWork mergeJoinWork = (MergeJoinWork) baseWork;
-              // disconnect the connection to union work and connect to merge work
-              followingWork = mergeJoinWork;
-              rWork = (ReduceWork) mergeJoinWork.getMainWork();
-            } else {
-              rWork = (ReduceWork) baseWork;
-            }
+          BaseWork baseWork = tezWork.getChildren(unionWork).get(index);
+          if (baseWork instanceof MergeJoinWork) {
+            MergeJoinWork mergeJoinWork = (MergeJoinWork) baseWork;
+            // disconnect the connection to union work and connect to merge work
+            followingWork = mergeJoinWork;
+            rWork = (ReduceWork) mergeJoinWork.getMainWork();
           } else {
-            throw new SemanticException("Following work not found for the reduce sink: "
-                + rs.getName());
+            rWork = (ReduceWork) baseWork;
           }
         } else {
           rWork = (ReduceWork) followingWork;
@@ -406,17 +396,29 @@ public Object process(Node nd, Stack<Node> stack,
     return null;
   }
 
-  private int getFollowingWorkIndex(TezWork tezWork, UnionWork unionWork, ReduceSinkOperator rs) {
+  /**
+   * Finds the position of the first child of {@code unionWork} (in
+   * {@code tezWork.getChildren(unionWork)} order) whose edge from {@code unionWork} is not a
+   * {@code CONTAINS} edge.
+   *
+   * @param tezWork the TezWork whose children and edge properties are queried
+   * @param unionWork the union work whose children are scanned
+   * @param rs the reduce sink whose following work is being looked up; only used in the
+   *     error message
+   * @return the index of that child within {@code tezWork.getChildren(unionWork)}
+   * @throws SemanticException if no child of {@code unionWork} has a non-{@code CONTAINS} edge
+   */
+  private int getFollowingWorkIndex(TezWork tezWork, UnionWork unionWork, ReduceSinkOperator rs)
+      throws SemanticException {
     int index = 0;
     for (BaseWork baseWork : tezWork.getChildren(unionWork)) {
-      if (tezWork.getEdgeProperty(unionWork, baseWork).equals(TezEdgeProperty.EdgeType.CONTAINS)) {
-        index++;
-      } else {
+      TezEdgeProperty edgeProperty = tezWork.getEdgeProperty(unionWork, baseWork);
+      if (edgeProperty.getEdgeType() != TezEdgeProperty.EdgeType.CONTAINS) {
         return index;
       }
+      index++;
     }
-
-    return -1;
+    throw new SemanticException("Following work not found for the reduce sink: " + rs.getName());
   }
 
   @SuppressWarnings("unchecked")
